package com.micro.magupe.sharding.executor;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CopyOnWriteArrayList;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import com.micro.magupe.sharding.connection.ShardingConnection;
import com.micro.magupe.sharding.merge.result.MemoryQueryResult;
import com.micro.magupe.sharding.merge.result.QueryResult;
import com.micro.magupe.sharding.router.SQLRouterResult;
import com.micro.magupe.sharding.router.SQLRouterUnit;

/**
 * Builds and runs the per-data-source JDBC statements for a routed SQL.
 *
 * <p>{@link #init(SQLRouterResult)} groups the routed SQL units by data source name. For each data
 * source it obtains one connection from the {@link ShardingConnection} and creates one
 * {@link Statement} per SQL on it. {@link #executeQuery()} then runs every unit sequentially.
 *
 * <p>The getters return the live internal collections, not copies.
 */
public final class StatementExecutor {

    private final ShardingConnection connection;

    /**
     * Result set type supplied at construction.
     *
     * <p>NOTE(review): this value is exposed via {@link #getResultSetType()} but is not passed to
     * {@code createStatement()} in this class. Statements are created with the JDBC defaults.
     * Confirm whether callers expect it to be applied.
     */
    private final int resultSetType;

    private final Collection<ShardingExecuteGroup<StatementExecuteUnit>> executeGroups = new LinkedList<>();

    /** Every ResultSet produced by {@link #executeQuery()}, kept so the owner can close them later. */
    private final List<ResultSet> resultSets = new CopyOnWriteArrayList<>();

    /** Populated only by {@link #cacheStatements()}. */
    private final List<Statement> statements = new LinkedList<>();

    public StatementExecutor(final ShardingConnection connection, int resultSetType) {
        this.connection = connection;
        this.resultSetType = resultSetType;
    }

    public int getResultSetType() {
        return resultSetType;
    }

    public ShardingConnection getConnection() {
        return connection;
    }

    public List<Statement> getStatements() {
        return statements;
    }

    public Collection<ShardingExecuteGroup<StatementExecuteUnit>> getExecuteGroups() {
        return executeGroups;
    }

    public List<ResultSet> getResultSets() {
        return resultSets;
    }

    /**
     * Creates the execute groups for the given routing result and appends them to this executor.
     *
     * <p>Groups are appended, so calling this more than once accumulates groups. If creation fails,
     * no groups are appended and every statement created during this call is closed.
     *
     * @param routerResult routing result whose router units are to be executed
     * @throws SQLException if obtaining a connection or creating a statement fails
     */
    public void init(SQLRouterResult routerResult) throws SQLException {
        getExecuteGroups().addAll(obtainExecuteGroups(routerResult.getRouterUnits()));
    }

    /**
     * Builds one execute group per data source.
     *
     * <p>If building any group fails, the statements of groups already built are closed before the
     * failure is rethrown. Those groups are otherwise unreachable and their statements would leak.
     */
    private Collection<? extends ShardingExecuteGroup<StatementExecuteUnit>> obtainExecuteGroups(
            Collection<SQLRouterUnit> routerUnits) throws SQLException {
        Map<String, List<String>> sqlUnitGroups = getSQLUnitGroups(routerUnits);
        Collection<ShardingExecuteGroup<StatementExecuteUnit>> result = new LinkedList<>();
        try {
            for (Entry<String, List<String>> entry : sqlUnitGroups.entrySet()) {
                result.add(getSQLExecuteGroups(entry.getKey(), entry.getValue()));
            }
        } catch (SQLException | RuntimeException ex) {
            for (ShardingExecuteGroup<StatementExecuteUnit> group : result) {
                closeStatements(group.getInputs(), ex);
            }
            throw ex;
        }
        return result;
    }

    /**
     * Creates one statement per SQL for a single data source.
     *
     * <p>All statements are created on the first connection returned by
     * {@code getConnections(dataSourceName, 1)}. If creation fails part-way, the statements created
     * so far are closed before rethrowing.
     *
     * @param dataSourceName data source the SQLs were routed to
     * @param sqls SQL strings to prepare statements for, in routing order
     * @return the execute group holding one unit per SQL
     * @throws SQLException if obtaining the connection or creating a statement fails
     */
    private ShardingExecuteGroup<StatementExecuteUnit> getSQLExecuteGroups(String dataSourceName, List<String> sqls) throws SQLException {
        List<Connection> connections = getConnection().getConnections(dataSourceName, 1);
        // Loop-invariant: every SQL unit of this data source shares one physical connection.
        Connection dataSourceConnection = connections.get(0);
        List<StatementExecuteUnit> units = new ArrayList<>(sqls.size());
        try {
            for (String sql : sqls) {
                units.add(new StatementExecuteUnit(new SQLRouterUnit(dataSourceName, sql), dataSourceConnection.createStatement()));
            }
        } catch (SQLException | RuntimeException ex) {
            // Without this, statements created before the failure are unreachable and never closed.
            closeStatements(units, ex);
            throw ex;
        }
        return new ShardingExecuteGroup<>(units);
    }

    /**
     * Best-effort close of every unit's statement.
     *
     * <p>Close failures are attached to {@code cause} as suppressed exceptions so they do not mask
     * the original error.
     */
    private static void closeStatements(Iterable<StatementExecuteUnit> units, Throwable cause) {
        for (StatementExecuteUnit unit : units) {
            try {
                unit.getStatement().close();
            } catch (SQLException closeFailure) {
                cause.addSuppressed(closeFailure);
            }
        }
    }

    /**
     * Groups the SQL of each router unit by data source name.
     *
     * <p>Both the first appearance order of data sources and the SQL order within each data source
     * are preserved.
     */
    private Map<String, List<String>> getSQLUnitGroups(Collection<SQLRouterUnit> routerUnits) {
        Map<String, List<String>> result = new LinkedHashMap<>(routerUnits.size(), 1);
        for (SQLRouterUnit routerUnit : routerUnits) {
            result.computeIfAbsent(routerUnit.getDataSourceName(), key -> new LinkedList<>())
                    .add(routerUnit.getSql());
        }
        return result;
    }

    /**
     * Executes every prepared unit sequentially and wraps each ResultSet in a MemoryQueryResult.
     *
     * <p>Each ResultSet is recorded in {@link #getResultSets()} as soon as it is obtained. Result
     * sets from units that succeeded before a failure therefore remain reachable for closing.
     *
     * @return one query result per unit, in group and unit order
     * @throws SQLException if any query fails
     */
    public List<QueryResult> executeQuery() throws SQLException {
        List<QueryResult> list = new ArrayList<>();
        for (ShardingExecuteGroup<StatementExecuteUnit> executeGroup : executeGroups) {
            for (StatementExecuteUnit statementExecuteUnit : executeGroup.getInputs()) {
                list.add(getQueryResult(statementExecuteUnit));
            }
        }
        return list;
    }

    private QueryResult getQueryResult(final StatementExecuteUnit statementExecuteUnit) throws SQLException {
        ResultSet resultSet = statementExecuteUnit.getStatement().executeQuery(statementExecuteUnit.getRouterUnit().getSql());
        // Track before wrapping so the ResultSet can be closed even if wrapping fails.
        getResultSets().add(resultSet);
        return new MemoryQueryResult(resultSet);
    }

    /**
     * Appends the statement of every unit in every execute group to {@link #getStatements()}.
     *
     * <p>Calling this more than once appends the statements again.
     */
    protected final void cacheStatements() {
        for (ShardingExecuteGroup<StatementExecuteUnit> each : executeGroups) {
            for (StatementExecuteUnit unit : each.getInputs()) {
                statements.add(unit.getStatement());
            }
        }
    }
}
